# core/formalization/action_space.py
from enum import Enum
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Optional
from .symbol_manager import SymbolManager

from llm.llm_wrapper import LLMWrapper
from utils.logger import Logger

class ActionType(Enum):
    """Identifiers for the available formalization actions.

    These are the keys under which ``FormalizationActionSpace`` stores its
    registered actions. Member declaration order is significant:
    ``get_action_index`` / ``get_action_type_by_index`` translate between a
    member and its position in that order. Append new members at the end so
    that existing indices stay stable.
    """

    FALLBACK = "fallback"
    SYMBOLIC_ABSTRACTION = "symbolic_abs"
    LOGICAL_ENCODING = "logic_encode"
    MATHEMATICAL_REPRESENTATION = "math_repr"
    DOMAIN_SPECIALIZATION = "domain_spec"
    METAPHORICAL_TRANSFORMATION = "metaphor"
    STRATEGIC_DECOMPOSITION = "strategic_decomp"


# NOTE: these are plain module-level functions. They were previously wrapped
# in ``@staticmethod``, which has no meaning outside a class body and makes the
# name uncallable on Python < 3.10 ("'staticmethod' object is not callable").
def get_action_index(actionType: ActionType) -> int:
    """Return the zero-based declaration position of ``actionType``.

    Raises:
        ValueError: if ``actionType`` is not an ``ActionType`` member.
    """
    return list(ActionType).index(actionType)


def get_action_type_by_index(idx: int) -> ActionType:
    """Return the ``ActionType`` declared at zero-based position ``idx``.

    Inverse of ``get_action_index``.

    Raises:
        IndexError: if ``idx`` is negative or not less than the number of
            members. Negative indices are rejected explicitly, because Python
            list indexing would otherwise silently map e.g. ``-1`` to the last
            action.
    """
    members = list(ActionType)
    if idx < 0:
        raise IndexError(f"action index out of range: {idx}")
    return members[idx]

class FormalizationAction(ABC):
    """Abstract base class for a single formalization action.

    A concrete action reports its type via ``get_type``, decides whether it is
    applicable to a given text via ``should_apply``, and performs the
    transformation via ``apply``. The logger, LLM wrapper and symbol manager
    passed to the constructor are stored as attributes for subclasses to use.
    """

    def __init__(self, logger: Logger, llm: LLMWrapper, symbol_manager: SymbolManager):
        self.logger = logger
        self.llm = llm
        self.symbol_manager = symbol_manager

    @abstractmethod
    def get_type(self) -> ActionType:
        """Return the key this action is registered under.

        ``FormalizationActionSpace`` uses this value as the key of its
        ``actions`` dict, which is annotated as keyed by ``ActionType``.
        """

    @abstractmethod
    def should_apply(self, text: str, context: Optional[Dict[str, Any]] = None) -> bool:
        """Return True if this action can be applied to ``text``.

        ``context`` is an optional dict of extra information; its schema is
        defined by the caller and concrete subclasses.
        """

    @abstractmethod
    def apply(self, text: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Apply this action to ``text`` and return the result as a dict.

        The keys of the returned dict are defined by concrete subclasses.
        """

class FormalizationActionSpace:
    """Registry of formalization actions, keyed by the value of ``get_type()``.

    At most one action is stored per type. The registry answers which actions
    apply to a given text, and dispatches ``apply`` calls by action type.
    """

    def __init__(self, logger: Logger, llm: LLMWrapper):
        self.logger = logger
        self.llm = llm

        # One action per type; filled via register_action(s). Dict insertion
        # order is preserved, so iteration follows registration order.
        self.actions: Dict[ActionType, FormalizationAction] = {}

    def register_action(self, action: FormalizationAction):
        """Register a single action; see ``register_actions``."""
        self.register_actions([action])

    def register_actions(self, actions: List[FormalizationAction]):
        """Register each action under the key returned by its ``get_type()``.

        If a type is already registered, the duplicate is skipped with a
        warning and the existing registration is kept. The remaining actions
        in ``actions`` are still registered.
        """
        self.logger.info("Formalization Action Space Register actions")
        for action in actions:
            action_type = action.get_type()

            if action_type in self.actions:
                self.logger.warning(f"Action {action_type} has already existed")
                # Skip only this duplicate instead of abandoning the rest of
                # the batch.
                continue

            self.actions[action_type] = action

    def get_actions_count(self) -> int:
        """Return the number of registered actions."""
        return len(self.actions)

    def get_all_actions(self) -> List[FormalizationAction]:
        """Return the registered actions as a new list, in registration order."""
        return list(self.actions.values())

    def get_available_actions(
        self, text: str, context: Optional[Dict[str, Any]] = None
    ) -> List[ActionType]:
        """Return the types of all registered actions whose
        ``should_apply(text, context)`` is truthy, in registration order."""
        return [
            action_type
            for action_type, action in self.actions.items()
            if action.should_apply(text, context)
        ]

    def apply_action(
        self, action_type: ActionType, text: str, context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Apply the action registered under ``action_type`` to ``text``.

        Returns whatever the action's ``apply`` returns.

        Raises:
            ValueError: if no action is registered for ``action_type``.
        """
        try:
            action = self.actions[action_type]
        except KeyError:
            raise ValueError(f"Unknown action: {action_type}") from None
        return action.apply(text, context)